
# Switch for optional sanity checks; they only run when this is TRUE.
testdata <- FALSE

# Vector of the distinct subject identifiers found in the PiB table.
subjects <- unique(PiB$Subject)

if (testdata) {
  # Check that both data sets contain the same subjects. Both sides are sorted
  # so the comparison does not depend on row order, and the result is printed
  # explicitly because auto-printing is switched off when the file is source()d.
  print(all.equal(sort(subjects), sort(unique(BCell_data$`Subject code`))))
}

# functions for PiB ----------------------------------------------------------------------------- #
## sex
# Distinct `gender` values recorded for one subject in the global `PiB` table.
#
# subject: subject identifier matched against `PiB$Subject`.
# Returns the unique entries of the `gender` column among the matching rows.
get_subjectSex <- function(subject = "01_0214_01") {
  is_subject <- PiB$Subject == subject
  unique(PiB[is_subject, ]$gender)
}

## age
# `Age` entries of one subject at timepoints 1 and 3, read from the global
# `PiB` table.
#
# subject: subject identifier matched against `PiB$Subject`.
# Returns the two selections combined with c(), labelled "age1" and "age3".
get_subjectAge <- function(subject = "01_0214_01") {
  rows <- PiB[PiB$Subject == subject, ]
  at_first <- rows[rows$Timepoint == 1, "Age"]
  at_third <- rows[rows$Timepoint == 3, "Age"]

  c("age1" = at_first, "age3" = at_third)
}

## PiB
# `PiB` entries of one subject at timepoints 1 and 3, read from the global
# `PiB` table.
#
# subject: subject identifier matched against `PiB$Subject`.
# Returns the two selections combined with c(), labelled "PiB1" and "PiB3".
get_subjectPiB <- function(subject = "01_0214_01") {
  rows <- PiB[PiB$Subject == subject, ]
  at_first <- rows[rows$Timepoint == 1, "PiB"]
  at_third <- rows[rows$Timepoint == 3, "PiB"]

  c("PiB1" = at_first, "PiB3" = at_third)
}

# original group allocation
# Distinct `PiB Groups` values recorded for one subject in the global `PiB`
# table.
#
# subject: subject identifier matched against `PiB$Subject`.
# Returns the unique entries of the `PiB Groups` column among the matching rows.
get_subjectGroup <- function(subject = "01_0214_01") {
  is_subject <- PiB$Subject == subject
  unique(PiB[is_subject, ]$`PiB Groups`)
}
# Group entry of every subject, labelled by subject id, then the group counts.
subjectGroup <- setNames(c(sapply(subjects, get_subjectGroup)), subjects)
table(subjectGroup)
# ----------------------------------------------------------------------------------------------- #

# Collect the rows of `data` that belong to the requested clusters.
#
# id:   vector of cluster ids. Rows come back cluster by cluster in the order
#       given here; an id listed twice contributes its rows twice.
# data: data frame with a `Cluster_id` column.
# Returns the stacked rows, or NULL when `id` is empty.
getCellCluster <- function(id = 1:3, data = BCell_data) {
  per_cluster <- lapply(id, function(i) data[data$Cluster_id == i, ])

  # Bind once at the end instead of growing the result inside a loop, which
  # copies the accumulated rows on every iteration. unname() keeps names on
  # `id` from leaking into the row names.
  do.call(rbind, unname(per_cluster))
}

# Draw one PiB-versus-abundance trajectory per subject for a single cluster
# and PiB group.
#
# cluster_id: value matched against `data$Cluster_id`.
# group:      value matched against `data$PiBGroup`.
# data:       data frame with columns Cluster_id, PiBGroup, Subject,
#             Timepoint, Abundance and PiB.
# imputed:    if TRUE all rows of a subject are connected; if FALSE only the
#             subject's first and third row are used (this presumably relies
#             on the rows being ordered by timepoint - TODO confirm).
#
# The plot title reports the number of missing abundances at timepoints 1-3.
# Called for its plotting side effect; returns NULL invisibly.
clusterpathways <- function(cluster_id = 14, group = 2, data = BCell_data2, imputed = TRUE) {
  cluster <- data[data$Cluster_id == cluster_id, ]
  clacc <- cluster[cluster$PiBGroup == group, ]
  clsubjects <- unique(clacc$Subject)

  # Number of missing abundance values at one timepoint.
  count_missing <- function(timepoint) {
    sum(is.na(clacc[clacc$Timepoint == timepoint, "Abundance"]))
  }

  plot(NULL, type = "b",
       xlim = range(clacc$Abundance, na.rm = TRUE),
       ylim = c(0.5, max(clacc$PiB, na.rm = TRUE)),
       ylab = "PiB", xlab = "Abundance (%)",
       main = paste0("T1NA:", count_missing(1), ";",
                     "T2NA:", count_missing(2), ";",
                     "T3NA:", count_missing(3)))

  for (i in seq_along(clsubjects)) {
    tmpsubject <- clacc[clacc$Subject == clsubjects[i], ]
    abundance <- tmpsubject$Abundance
    pib <- tmpsubject$PiB

    if (!imputed) {
      abundance <- abundance[c(1, 3)]
      pib <- pib[c(1, 3)]
    }

    # "l" draws the uninterrupted line, "b" adds the point markers on top.
    lines(y = pib, x = abundance, type = "l", col = i)
    lines(y = pib, x = abundance, type = "b", col = i)
  }
}

# Print a LaTeX coefficient table for one fitted model and draw the matching
# cluster trajectories.
#
# listoffits:         a single fitted model object; vcov(), fixef() and
#                     summary() are called on it (presumably an lme4/lmerTest
#                     mixed model - TODO confirm).
# useddata:           data frame handed on to clusterpathways().
# clusterpathimputed: forwarded as `imputed` to clusterpathways().
# tablecaption:       text placed in front of "for cluster <i>" in the caption.
#
# NOTE(review): `i` is not a parameter. It is used in the default of
# `listoffits`, in the table caption and as `cluster_id` for the plot, so it
# has to exist in the calling/global environment when this function runs.
modelfitsoutput <- function(listoffits = all_mixed_normalfits[[i]], useddata = BCell_data, clusterpathimputed = F, tablecaption = "mixed effect model") {
  # standard errors: square roots of the diagonal of the model's vcov() matrix
  se <- sqrt(diag(as.matrix(vcov(listoffits))))
  # set the xtable printing options for this call only; the previous values are
  # restored when the function exits
  oopt <- options(xtable.include.rownames = TRUE, xtable.floating = TRUE,
                  xtable.type = "latex", xtable.size = "footnotesize",
                  xtable.table.placement = "!h", xtable.sanitize.colnames.function = identity)
  on.exit(options(oopt))
  # table rows: one per fixed effect with estimate, normal-approximation 95%
  # interval (estimate -/+ qnorm(.975) * se), se and the fifth column of the
  # summary coefficient matrix (presumably the p-value - TODO confirm), followed
  # by two rows holding the first and second element of
  # MuMIn::r.squaredGLMM(), padded with NA to the five table columns
  print(xtable::xtable(rbind(cbind(Est = fixef(listoffits),
                                   from = fixef(listoffits) - qnorm(.975) * se,
                                   to = fixef(listoffits) + qnorm(.975) * se,
                                   se = se,
                                   pvalue = summary(listoffits)$coefficients[,5]),
                             R2marginal = c(MuMIn::r.squaredGLMM(listoffits)[1], NA, NA, NA, NA),
                             R2conditional = c(MuMIn::r.squaredGLMM(listoffits)[2], NA, NA, NA, NA)),
                       caption = paste(tablecaption, "for cluster", i), digits = c(3, 3, 3, 3, 3, 5)))

  # NOTE(review): par(mfrow) is changed here and not restored afterwards; the
  # plot is always drawn for group = 2.
  { par(mfrow = c(1,1))
    # plot(y = resid(listoffits), x = fitted(listoffits), xlab = "fitted", ylab = "residuals")
    clusterpathways(cluster_id = i, group = 2, data = useddata, imputed = clusterpathimputed)
    # plot(x = sort(ranef(listoffits)$Subject[,1]), ylab = "sorted random intercept", xlab = "subject")

    # x <- seq(range(getCellCluster(id = i)$Abundance, na.rm = T)[1], range(getCellCluster(id = i)$Abundance, na.rm = T)[2], length.out = 100)
    # plot(x = x, xlim = c(min(getCellCluster(id = i)$Abundance, na.rm = T), max(getCellCluster(id = i)$Abundance, na.rm = T)),
    #      ylim = c(0.5, max(getCellCluster(id = i)$PiB, na.rm = T)),
    #      y = fixef(listoffits)[1] + fixef(listoffits)[2]*x +
    #        fixef(listoffits)[3]*mean(getCellCluster(id = i)$Age, na.rm = T),
    #      type = "l", xlab = "Abundance", ylab = "PiB")
    # lines(x = x,
    #       y = fixef(listoffits)[1] + fixef(listoffits)[2]*x +
    #         fixef(listoffits)[3]*min(getCellCluster(id = i)$Age, na.rm = T), col = "red")
    # lines(x = x,
    #       y = fixef(listoffits)[1] + fixef(listoffits)[2]*x +
    #         fixef(listoffits)[3]*max(getCellCluster(id = i)$Age, na.rm = T), col = "green")
    # legend("bottomleft", c("mean age", "min age", "max age"), col = c("black", "red", "green"), pch = 16)
  }
}
# ----------------------------------------------------------------------------------------------- #


# Collect the rows of `data` that belong to the requested clusters.
#
# id:   vector of cluster ids. Rows come back cluster by cluster in the order
#       given here; an id listed twice contributes its rows twice.
# data: data frame with a `Cluster_id` column.
# Returns the stacked rows, or NULL when `id` is empty.
getCell2Cluster <- function(id, data) {
  per_cluster <- lapply(id, function(i) data[data$Cluster_id == i, ])

  # Bind once at the end instead of growing the result inside a loop, which
  # copies the accumulated rows on every iteration. unname() keeps names on
  # `id` from leaking into the row names.
  do.call(rbind, unname(per_cluster))
}


# Reduce every subject of one cluster / PiB group to the timepoints that are
# used for a two-point comparison.
#
# clusterid: cluster id(s) passed to getCell2Cluster(). The default `i` is not
#            a parameter and is looked up in the calling/global environment.
# idata:     data frame with columns Cluster_id, PiBGroup, Subject, Timepoint
#            and Abundance.
# group:     value matched against `PiBGroup`.
#
# Per subject, depending on which abundances are missing:
#   timepoints 1 and 3 both observed -> rows of timepoint 2 are dropped
#   only timepoint 1 missing         -> rows of timepoint 1 are dropped
#   only timepoint 3 missing         -> rows of timepoint 3 are dropped
#   both missing                     -> all rows are kept
# Each subject is expected to have exactly one row for timepoint 1 and one for
# timepoint 3; otherwise the missingness test is not a single TRUE/FALSE and
# the function stops with an error.
#
# Returns the stacked per-subject rows (zero rows if the group is empty).
Clustersubsetfor2t <- function(clusterid = i, idata, group = 2) {
  Cell_cli <- getCell2Cluster(id = clusterid, data = idata)
  Cell_cli <- Cell_cli[Cell_cli$PiBGroup == group, ]

  pibsubjects <- unique(Cell_cli$Subject)

  subsetfor2 <- function(subj = pibsubjects[1]) {
    subjdata <- Cell_cli[Cell_cli$Subject == subj, ]

    t1_missing <- is.na(subjdata[subjdata$Timepoint == 1, "Abundance"])
    t3_missing <- is.na(subjdata[subjdata$Timepoint == 3, "Abundance"])

    if (t1_missing && t3_missing) {
      return(subjdata)
    }

    drop_timepoint <- if (t1_missing) 1 else if (t3_missing) 3 else 2

    # Logical exclusion instead of `-which(...)`: a negative index built from
    # an empty which() selects zero rows, so a subject without any row for the
    # dropped timepoint used to vanish from the result entirely.
    subjdata[!(subjdata$Timepoint %in% drop_timepoint), ]
  }

  # Start from a zero-row frame so the columns survive an empty group, then
  # bind all subjects at once instead of growing the result in a loop.
  per_subject <- lapply(pibsubjects, subsetfor2)
  do.call(rbind, c(list(Cell_cli[0, ]), unname(per_subject)))
}

# Plot the abundance of one cluster over the timepoints for the subjects of one
# PiB group, restricted to a given direction of change.
#
# inputdata: data frame passed to getCellCluster() and Clustersubsetfor2t();
#            the columns PiBGroup, Subject, Timepoint and Abundance are used.
# clusterid: cluster to show.
# type:      "pos"  - subjects with two non-missing values whose first
#                     abundance is <= the second,
#            "neg"  - subjects with two non-missing values whose first
#                     abundance is > the second,
#            "dim1" - subjects with a single remaining row or a missing value.
#            Any other value draws no subject.
# accum:     TRUE selects PiBGroup 2, FALSE selects PiBGroup 1.
# mainpaste: text placed in front of "cluster <clusterid>" in the title.
#
# For "pos"/"neg" all of a subject's points are drawn and its two retained
# points are joined by a grey segment; the legend shows the number of subjects
# drawn. Returns the value of legend().
plot_slopes <- function(inputdata = CD8TCell_data, clusterid = 2, type = "pos", # type %in% c("dim1", "pos", "neg")
                        accum = TRUE, mainpaste = "CD8TCells") {

  group <- if (accum) 2 else 1

  clusterdata <- getCellCluster(id = clusterid, data = inputdata)
  clusterdata <- clusterdata[clusterdata$PiBGroup == group, ]

  cli2time_data <- Clustersubsetfor2t(clusterid = clusterid, idata = inputdata, group = group)
  cli2time_data <- cli2time_data[cli2time_data$PiBGroup == group, ]

  subjects <- unique(cli2time_data$Subject)

  plot(NULL, xlim = c(1, 3), ylim = c(0, max(clusterdata$Abundance, na.rm = TRUE) + 1),
       ylab = "Abundance (%)", xlab = "Timepoint", main = paste(mainpaste, "cluster", clusterid),
       xaxt = "n")
  axis(side = 1, at = 1:3, labels = 1:3)

  nsubj <- 0
  for (i in seq_along(subjects)) {

    subjdata <- cli2time_data[cli2time_data$Subject == subjects[i], ]
    compsubjdata <- clusterdata[clusterdata$Subject == subjects[i], ]

    complete_pair <- nrow(subjdata) == 2 && !any(is.na(subjdata$Abundance))

    if (type %in% c("pos", "neg") && complete_pair) {
      # "pos" keeps non-decreasing pairs, "neg" keeps strictly decreasing ones.
      non_decreasing <- subjdata$Abundance[1] <= subjdata$Abundance[2]
      if ((type == "pos") == non_decreasing) {
        points(y = compsubjdata$Abundance, x = compsubjdata$Timepoint, col = i)
        segments(x0 = as.numeric(subjdata$Timepoint[1]), y0 = subjdata$Abundance[1],
                 x1 = as.numeric(subjdata$Timepoint[2]), y1 = subjdata$Abundance[2], col = "gainsboro")

        nsubj <- nsubj + 1
      }
    }

    if (type == "dim1") {
      if (nrow(subjdata) == 1 || any(is.na(subjdata$Abundance))) {
        points(y = compsubjdata$Abundance, x = compsubjdata$Timepoint, col = i)

        nsubj <- nsubj + 1
      }
    }

  }

  legend("top", legend = paste("# patients:", as.character(nsubj)))

}

# Heat map of the average per-subject abundance slope for every cluster,
# separately for PiBGroup 2 (labelled "accumulator") and PiBGroup 1 (labelled
# "non-accumulator").
#
# inputdata:      data frame with a Cluster_id column; it is passed to
#                 Clustersubsetfor2t(), whose result must provide Subject,
#                 Abundance and Age.
# mainpaste:      plot title, also used to build the CSV file names.
# saveascsv:      if TRUE the 2 x clusters matrix of average slopes is written
#                 to "<mainpaste>.csv".
# saveslopesascv: if TRUE the per-subject slopes are written to
#                 "acc<mainpaste>.csv" and "nonacc<mainpaste>.csv".
#
# Clusters 1..max(Cluster_id) are evaluated. The colour scale is fixed to
# [-1.4, 1.4]. Called for its plotting / file side effects.
plot_slopeheatmap <- function(inputdata = CD8TCell_data, mainpaste = "CD8TCells", saveascsv = FALSE, saveslopesascv = FALSE) {

  # Slope (abundance change divided by Age change) of every subject in a
  # two-timepoint data set. Subjects without exactly two rows, or with a
  # missing Abundance or Age, keep NA.
  subject_slopes <- function(twopoint_data) {
    subject_ids <- unique(twopoint_data$Subject)
    slopes <- rep(NA_real_, length(subject_ids))

    for (k in seq_along(subject_ids)) {
      subjdata <- twopoint_data[twopoint_data$Subject == subject_ids[k], ]

      usable <- nrow(subjdata) == 2 &&
        !any(is.na(subjdata$Abundance)) &&
        !any(is.na(subjdata$Age))
      if (!usable) {
        next
      }

      # The slope assumes the earlier measurement comes first.
      if (subjdata$Age[1] > subjdata$Age[2])
        stop("wrong sorted")

      slopes[k] <- (subjdata$Abundance[2] - subjdata$Abundance[1]) / (subjdata$Age[2] - subjdata$Age[1])
    }

    slopes
  }

  # Per-subject slopes of one cluster for both PiB groups.
  getcluster_acc_nonacc_slopes <- function(clusterid, inputdata) {
    acc_cli2time_data <- Clustersubsetfor2t(clusterid = clusterid, idata = inputdata, group = 2)
    nonacc_cli2time_data <- Clustersubsetfor2t(clusterid = clusterid, idata = inputdata, group = 1)

    list(accslopes = subject_slopes(acc_cli2time_data),
         nonaccslopes = subject_slopes(nonacc_cli2time_data))
  }

  clusters <- as.numeric(unique(inputdata$Cluster_id))
  cluster_ids <- 1:max(clusters)

  # Compute the slopes of each cluster once and reuse them for the optional
  # per-subject export and for the averages.
  slopes_by_cluster <- lapply(cluster_ids, FUN = function(x) {
    getcluster_acc_nonacc_slopes(clusterid = x, inputdata = inputdata)
  })

  if (saveslopesascv) {
    acc_slopes <- sapply(slopes_by_cluster, FUN = function(s) s$accslopes)
    nonacc_slopes <- sapply(slopes_by_cluster, FUN = function(s) s$nonaccslopes)

    colnames(acc_slopes) <- colnames(nonacc_slopes) <- as.character(cluster_ids)
    write.csv(acc_slopes, file = paste0("acc", mainpaste, ".csv"))
    write.csv(nonacc_slopes, file = paste0("nonacc", mainpaste, ".csv"))
  }

  # Row 1: PiBGroup 2, row 2: PiBGroup 1; one column per cluster.
  averageslopes <- sapply(slopes_by_cluster, FUN = function(s) {
    c(mean(s$accslopes, na.rm = TRUE), mean(s$nonaccslopes, na.rm = TRUE))
  })

  if (saveascsv) {
    write.csv(averageslopes, file = paste0(mainpaste, ".csv"))
  }

  fields::image.plot(x = cluster_ids, y = 0:1, z = t(averageslopes),
                     xlab = "", ylab = "", xaxt = "n", yaxt = "n", bty = "n",
                     # col = colorspace::diverge_hcl(256, "Blue-Red"),
                     col = fields::tim.colors(n = 256),
                     main = mainpaste,
                     zlim = c(-1.4, 1.4))
  axis(side = 1, at = cluster_ids, labels = cluster_ids, las = 1)
  axis(side = 2, at = 0:1, labels = c("accumulator", "non-accumulator"), las = 2)
  abline(h = .5, col = "white", lwd = 2)
  # Cell separators follow the number of clusters instead of a fixed 26.
  abline(v = c(-.5, 0:max(clusters) + .5), col = "white", lwd = 2)
  box(lwd = 2, col = "white")
}



# Scatter plot of the change in PiB against the change in abundance between a
# subject's two retained timepoints, for one cluster and one PiB group.
#
# inputdata: data frame passed to getCellCluster() and Clustersubsetfor2t();
#            the columns PiBGroup, Subject, Abundance, PiB and Delta_PiB are
#            used.
# clusterid: cluster to show.
# type:      "pos" - subjects with two non-missing values whose first
#                    abundance is <= the second,
#            "neg" - subjects with two non-missing values whose first
#                    abundance is > the second.
#            Any other value (including "dim1", whose branch is disabled
#            below) draws no subject.
# accum:     TRUE selects PiBGroup 2, FALSE selects PiBGroup 1.
# mainpaste: text placed in front of "cluster <clusterid>" in the title.
#
# The x axis is fixed to [-10, 10]; the legend shows the number of subjects
# drawn. Returns the value of legend().
plot_deltas <- function(inputdata = CD8TCell_data, clusterid = 2, type = "pos", # type %in% c("dim1", "pos", "neg")
                        accum = TRUE, mainpaste = "CD8TCells") {

  group <- if (accum) 2 else 1

  clusterdata <- getCellCluster(id = clusterid, data = inputdata)
  clusterdata <- clusterdata[clusterdata$PiBGroup == group, ]

  cli2time_data <- Clustersubsetfor2t(clusterid = clusterid, idata = inputdata, group = group)
  cli2time_data <- cli2time_data[cli2time_data$PiBGroup == group, ]

  subjects <- unique(cli2time_data$Subject)

  plot(NULL, ylim = c(-0.1, max(clusterdata$Delta_PiB, na.rm = TRUE)), xlim = c(-10, 10),
       xlab = "Delta Abundance (%)", ylab = "Delta PiB", main = paste(mainpaste, "cluster", clusterid))
  abline(v = 0, lty = 2, col = "gainsboro")
  abline(h = 0, lty = 2, col = "gainsboro")

  nsubj <- 0
  for (i in seq_along(subjects)) {

    subjdata <- cli2time_data[cli2time_data$Subject == subjects[i], ]

    complete_pair <- nrow(subjdata) == 2 && !any(is.na(subjdata$Abundance))

    if (type %in% c("pos", "neg") && complete_pair) {
      # "pos" keeps non-decreasing pairs, "neg" keeps strictly decreasing ones.
      non_decreasing <- subjdata$Abundance[1] <= subjdata$Abundance[2]
      if ((type == "pos") == non_decreasing) {
        points(x = subjdata$Abundance[2] - subjdata$Abundance[1],
               y = subjdata$PiB[2] - subjdata$PiB[1], col = i)

        nsubj <- nsubj + 1
      }
    }

    # if (type == "dim1") {
    #   compsubjdata <- clusterdata[clusterdata$Subject == subjects[i], ]
    #   if (dim(subjdata)[1] == 1 | any(is.na(subjdata$Abundance))) {
    #     points(x = na.omit(compsubjdata$Abundance), y = 0, col = i)
    #
    #     nsubj <- nsubj + 1
    #   }
    # }

  }

  legend("top", legend = paste("# patients:", as.character(nsubj)))

}



